{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# HybridConditional"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "license_cell",
      "metadata": {
        "tags": [
          "remove-cell"
        ]
      },
      "source": [
        "GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,\nAtlanta, Georgia 30332-0415\nAll Rights Reserved\n\nAuthors: Frank Dellaert, et al. (see THANKS for the full author list)\n\nSee LICENSE for the license information"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<a href=\"https://colab.research.google.com/github/borglab/gtsam/blob/develop/gtsam/hybrid/doc/HybridConditional.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "9cb314a6",
      "metadata": {
        "tags": [
          "remove-cell"
        ]
      },
      "outputs": [],
      "source": [
        "try:\n",
        "    import google.colab\n",
        "    %pip install --quiet gtsam-develop\n",
        "except ImportError:\n",
        "    pass  # Not running on Colab, do nothing"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "94f2d735",
      "metadata": {},
      "source": [
        "`HybridConditional` acts as a **type-erased wrapper** for different kinds of conditional distributions that can appear in a `HybridBayesNet` or `HybridBayesTree`. It allows these containers to hold conditionals resulting from eliminating different types of variables (discrete, continuous, or mixtures) without needing to be templated on the specific conditional type.\n",
        "\n",
        "A `HybridConditional` object internally holds a shared pointer to one of the following concrete conditional types:\n",
        "*   `gtsam.GaussianConditional`\n",
        "*   `gtsam.DiscreteConditional`\n",
        "*   `gtsam.HybridGaussianConditional`\n",
        "\n",
        "It inherits from `HybridFactor` and `Conditional<HybridFactor, HybridConditional>`, providing access to both factor-like properties (keys) and conditional-like properties (frontals, parents).\n",
        "\n",
        "```mermaid\n",
        "graph TD\n",
        "    HybridConditional --> HybridFactor\n",
        "    HybridConditional --> Conditional\n",
        "    HybridConditional -- Holds shared pointer to --> GaussianConditional\n",
        "    HybridConditional -- Holds shared pointer to --> DiscreteConditional\n",
        "    HybridConditional -- Holds shared pointer to --> HybridGaussianConditional\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "6e6df071",
      "metadata": {},
      "outputs": [],
      "source": [
        "import gtsam\n",
        "import numpy as np\n",
        "\n",
        "from gtsam import (\n",
        "    GaussianConditional,\n",
        "    DiscreteConditional,\n",
        "    HybridConditional,\n",
        "    HybridGaussianConditional,\n",
        ")\n",
        "from gtsam.symbol_shorthand import X, D"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "23447d16",
      "metadata": {},
      "source": [
        "## Initialization\n",
        "\n",
        "A `HybridConditional` is created by wrapping a shared pointer to one of the concrete conditional types. These concrete conditionals are usually obtained from factor graph elimination."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "f5ff356f",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "HybridConditional from GaussianConditional:\n",
            "Hybrid Conditional\n",
            "p(x0 | x1)\n",
            "  R = [ 2 ]\n",
            "  S[x1] = [ 0.5 ]\n",
            "  d = [ 1 ]\n",
            "  logNormalizationConstant: 0.467356\n",
            "isotropic dim=1 sigma=0.5\n",
            "\n",
            "HybridConditional from DiscreteConditional:\n",
            "Hybrid Conditional\n",
            " P( d0 | d1 ):\n",
            " Choice(d1) \n",
            " 0 Choice(d0) \n",
            " 0 0 Leaf  0.8\n",
            " 0 1 Leaf  0.2\n",
            " 1 Choice(d0) \n",
            " 1 0 Leaf  0.2\n",
            " 1 1 Leaf  0.8\n",
            "\n",
            "\n",
            "HybridConditional from HybridGaussianConditional:\n",
            "Hybrid Conditional\n",
            " P( x2 | d2)\n",
            " Discrete Keys = (d2, 2), \n",
            " logNormalizationConstant: 0.467356\n",
            "\n",
            " Choice(d2) \n",
            " 0 Leaf p(x2)\n",
            "  R = [ 1 ]\n",
            "  d = [ 0 ]\n",
            "  mean: 1 elements\n",
            "  x2: 0\n",
            "  logNormalizationConstant: -0.918939\n",
            "  Noise model: unit (1) \n",
            "\n",
            " 1 Leaf p(x2)\n",
            "  R = [ 2 ]\n",
            "  d = [ 10 ]\n",
            "  mean: 1 elements\n",
            "  x2: 5\n",
            "  logNormalizationConstant: 0.467356\n",
            "isotropic dim=1 sigma=0.5\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# --- Create concrete conditionals (examples) ---\n",
        "# 1. GaussianConditional P(X0 | X1)\n",
        "gc = GaussianConditional(X(0), np.array([1.0]), np.eye(1)*2.0, # d, R\n",
        "                         X(1), np.array([[0.5]]),                # Parent, S\n",
        "                         gtsam.noiseModel.Diagonal.Sigmas([0.5])) # sigma=0.5 whitens R, S, d -> effective std = sigma/R = 0.25\n",
        "\n",
        "# 2. DiscreteConditional P(D0 | D1) (D0, D1 binary)\n",
        "dk0 = (D(0), 2)\n",
        "dk1 = (D(1), 2)\n",
        "dc = DiscreteConditional(dk0, [dk1], \"4/1 1/4\") # P(D0|D1=0) = 80/20, P(D0|D1=1) = 20/80\n",
        "\n",
        "# 3. HybridGaussianConditional P(X2 | D2) (X2 1D, D2 binary)\n",
        "dk2 = (D(2), 2)\n",
        "# Mode 0: P(X2 | D2=0) with R=1, d=0, sigma=1 -> mean 0, std 1\n",
        "hgc_gc0 = GaussianConditional(X(2), np.zeros(1), np.eye(1), gtsam.noiseModel.Unit.Create(1))\n",
        "# Mode 1: P(X2 | D2=1) with R=2, d=10, sigma=0.5 -> mean d/R = 5, std sigma/R = 0.25\n",
        "hgc_gc1 = GaussianConditional(X(2), np.array([10.0]), np.eye(1)*2.0, gtsam.noiseModel.Isotropic.Sigma(1,0.5))\n",
        "# This constructor takes the list of conditionals directly, provided their parents match\n",
        "hgc = HybridGaussianConditional(dk2, [hgc_gc0, hgc_gc1])\n",
        "\n",
        "# --- Wrap them into HybridConditionals ---\n",
        "hybrid_cond_g = HybridConditional(gc)\n",
        "hybrid_cond_d = HybridConditional(dc)\n",
        "hybrid_cond_h = HybridConditional(hgc)\n",
        "\n",
        "print(\"HybridConditional from GaussianConditional:\")\n",
        "hybrid_cond_g.print()\n",
        "print(\"\\nHybridConditional from DiscreteConditional:\")\n",
        "hybrid_cond_d.print()\n",
        "print(\"\\nHybridConditional from HybridGaussianConditional:\")\n",
        "hybrid_cond_h.print()"
      ]
    },
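    {
      "cell_type": "markdown",
      "id": "worked_log_norm_const",
      "metadata": {},
      "source": [
        "The `logNormalizationConstant` values printed above can be reproduced by hand. The worked equations below are a sketch, assuming the noise model whitens `R`, `S`, and `d` by dividing them by $\\sigma$; that assumption is consistent with the printed numbers. Under it, the scalar conditional `gc` is the density\n",
        "\n",
        "$$\n",
        "p(x_0 \\mid x_1) = \\mathcal{N}\\left(x_0;\\ \\frac{d - S x_1}{R},\\ \\left(\\frac{\\sigma}{R}\\right)^2\\right),\n",
        "$$\n",
        "\n",
        "so its standard deviation is $\\sigma / R = 0.5 / 2 = 0.25$, and its log normalization constant is\n",
        "\n",
        "$$\n",
        "\\log k = -\\tfrac{1}{2}\\log(2\\pi) + \\log\\frac{R}{\\sigma} = -0.918939 + \\log 4 \\approx 0.467356.\n",
        "$$\n",
        "\n",
        "For mode 0 of `hgc` ($R = 1$, $\\sigma = 1$) the second term vanishes, leaving $-0.918939$."
      ]
    },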
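    {
      "cell_type": "markdown",
      "id": "da85d70c",
      "metadata": {},
      "source": [
        "## Accessing Information and Inner Type\n",
        "\n",
        "You can access keys, frontals, and parents as with any conditional. The `isContinuous()`, `isDiscrete()`, and `isHybrid()` predicates report which concrete type is wrapped, and `asGaussian()`, `asDiscrete()`, and `asHybrid()` return the wrapped conditional, or `None` when the wrapped type does not match."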
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "885418ff",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "--- Inspecting HybridConditional from Gaussian ---\n",
            "Keys: [8646911284551352320, 8646911284551352321]\n",
            "Frontals: 1\n",
            "Parents: 1\n",
            "Is Continuous? True\n",
            "Is Discrete? False\n",
            "Is Hybrid? False\n",
            "Successfully cast back to GaussianConditional:\n",
            "GaussianConditional p(x0 | x1)\n",
            "  R = [ 2 ]\n",
            "  S[x1] = [ 0.5 ]\n",
            "  d = [ 1 ]\n",
            "  logNormalizationConstant: 0.467356\n",
            "isotropic dim=1 sigma=0.5\n",
            "Cast back to DiscreteConditional successful? False\n",
            "\n",
            "--- Inspecting HybridConditional from Hybrid ---\n",
            "Keys: [8646911284551352322, 7205759403792793602]\n",
            "Frontals: 1\n",
            "Parents: 1\n",
            "Continuous Keys: [8646911284551352322]\n",
            "Discrete Keys: \n",
            "d2 2\n",
            "\n",
            "Is Continuous? False\n",
            "Is Discrete? False\n",
            "Is Hybrid? True\n",
            "Successfully cast back to HybridGaussianConditional.\n"
          ]
        }
      ],
      "source": [
        "print(\"\\n--- Inspecting HybridConditional from Gaussian ---\")\n",
        "print(f\"Keys: {hybrid_cond_g.keys()}\")\n",
        "print(f\"Frontals: {hybrid_cond_g.nrFrontals()}\")\n",
        "print(f\"Parents: {hybrid_cond_g.nrParents()}\")\n",
        "print(f\"Is Continuous? {hybrid_cond_g.isContinuous()}\") # True\n",
        "print(f\"Is Discrete? {hybrid_cond_g.isDiscrete()}\")   # False\n",
        "print(f\"Is Hybrid? {hybrid_cond_g.isHybrid()}\")     # False\n",
        "\n",
        "# Try casting back\n",
        "inner_gaussian = hybrid_cond_g.asGaussian()\n",
        "if inner_gaussian:\n",
        "    print(\"Successfully cast back to GaussianConditional:\")\n",
        "    inner_gaussian.print()\n",
        "else:\n",
        "    print(\"Failed to cast back to GaussianConditional.\")\n",
        "\n",
        "inner_discrete = hybrid_cond_g.asDiscrete()\n",
        "print(f\"Cast back to DiscreteConditional successful? {inner_discrete is not None}\")\n",
        "\n",
        "print(\"\\n--- Inspecting HybridConditional from Hybrid ---\")\n",
        "print(f\"Keys: {hybrid_cond_h.keys()}\")\n",
        "print(f\"Frontals: {hybrid_cond_h.nrFrontals()}\")\n",
        "print(f\"Parents: {hybrid_cond_h.nrParents()}\") # Counts continuous and discrete parents; here the only parent is D2\n",
        "print(f\"Continuous Keys: {hybrid_cond_h.continuousKeys()}\")\n",
        "print(f\"Discrete Keys: {hybrid_cond_h.discreteKeys()}\")\n",
        "print(f\"Is Continuous? {hybrid_cond_h.isContinuous()}\") # False\n",
        "print(f\"Is Discrete? {hybrid_cond_h.isDiscrete()}\")   # False\n",
        "print(f\"Is Hybrid? {hybrid_cond_h.isHybrid()}\")     # True\n",
        "\n",
        "# Try casting back\n",
        "inner_hybrid = hybrid_cond_h.asHybrid()\n",
        "if inner_hybrid:\n",
        "    print(\"Successfully cast back to HybridGaussianConditional.\")\n",
        "else:\n",
        "    print(\"Failed to cast back to HybridGaussianConditional.\")"
      ]
    },
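    {
      "cell_type": "markdown",
      "id": "worked_key_encoding",
      "metadata": {},
      "source": [
        "The large integers printed by `keys()` are encoded symbols. As a sketch, assuming the `symbol_shorthand` helpers store the character's ASCII code in the top 8 bits of a 64-bit key and the index $j$ in the remaining 56 bits (an assumption that reproduces the values printed above):\n",
        "\n",
        "$$\n",
        "\\text{key}(c, j) = \\operatorname{ord}(c) \\cdot 2^{56} + j\n",
        "$$\n",
        "\n",
        "With $\\operatorname{ord}(\\text{x}) = 120$ and $\\operatorname{ord}(\\text{d}) = 100$:\n",
        "\n",
        "$$\n",
        "X(0) = 120 \\cdot 2^{56} = 8646911284551352320, \\qquad D(2) = 100 \\cdot 2^{56} + 2 = 7205759403792793602.\n",
        "$$"
      ]
    },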
    {
      "cell_type": "markdown",
      "id": "23fc9fc6",
      "metadata": {},
      "source": [
        "## Evaluation (`error`, `logProbability`, `evaluate`)\n",
        "\n",
        "These methods delegate to the underlying concrete conditional's implementation. They require a `HybridValues` object containing assignments for all involved variables (frontal and parents)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "ce6716ae",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Gaussian HybridConditional P(X0=2|X1=1):\n",
            "  Error: 24.5\n",
            "  LogProbability: -24.032644172084783\n",
            "  Probability: 3.6538881633458336e-11\n",
            "\n",
            "Discrete HybridConditional P(D0=1|D1=0):\n",
            "  Error: 1.6094379124341003\n",
            "  LogProbability: -1.6094379124341003\n",
            "  Probability: 0.2\n",
            "\n",
            "Hybrid Gaussian HybridConditional P(X2=4.5|D2=1):\n",
            "  Error: 2.0\n",
            "  LogProbability: -1.5326441720847823\n",
            "  Probability: 0.21596386605275217\n"
          ]
        }
      ],
      "source": [
        "# --- Evaluate the Gaussian Conditional P(X0 | X1) ---\n",
        "vals_g = gtsam.HybridValues()\n",
        "vals_g.insert(X(0), np.array([2.0])) # Frontal\n",
        "vals_g.insert(X(1), np.array([1.0])) # Parent\n",
        "\n",
        "err_g = hybrid_cond_g.error(vals_g)\n",
        "log_prob_g = hybrid_cond_g.logProbability(vals_g)\n",
        "prob_g = hybrid_cond_g.evaluate(vals_g) # Equivalent to exp(logProbability)\n",
        "\n",
        "print(\"\\nGaussian HybridConditional P(X0=2|X1=1):\")\n",
        "print(f\"  Error: {err_g}\")\n",
        "print(f\"  LogProbability: {log_prob_g}\")\n",
        "print(f\"  Probability: {prob_g}\")\n",
        "\n",
        "# --- Evaluate the Discrete Conditional P(D0 | D1) ---\n",
        "vals_d = gtsam.HybridValues()\n",
        "vals_d.insert(D(0), 1) # Frontal = 1\n",
        "vals_d.insert(D(1), 0) # Parent = 0\n",
        "\n",
        "err_d = hybrid_cond_d.error(vals_d) # -log(P(D0=1|D1=0)) = -log(0.2)\n",
        "log_prob_d = hybrid_cond_d.logProbability(vals_d) # log(0.2)\n",
        "prob_d = hybrid_cond_d.evaluate(vals_d)         # 0.2\n",
        "\n",
        "print(\"\\nDiscrete HybridConditional P(D0=1|D1=0):\")\n",
        "print(f\"  Error: {err_d}\")\n",
        "print(f\"  LogProbability: {log_prob_d}\")\n",
        "print(f\"  Probability: {prob_d}\")\n",
        "\n",
        "# --- Evaluate the Hybrid Gaussian Conditional P(X2 | D2) ---\n",
        "vals_h = gtsam.HybridValues()\n",
        "vals_h.insert(X(2), np.array([4.5])) # Frontal\n",
        "vals_h.insert(D(2), 1)             # Parent (selects mode 1: mean 5, std 0.25)\n",
        "\n",
        "err_h = hybrid_cond_h.error(vals_h)\n",
        "log_prob_h = hybrid_cond_h.logProbability(vals_h)\n",
        "prob_h = hybrid_cond_h.evaluate(vals_h)\n",
        "\n",
        "print(\"\\nHybrid Gaussian HybridConditional P(X2=4.5|D2=1):\")\n",
        "print(f\"  Error: {err_h}\")\n",
        "print(f\"  LogProbability: {log_prob_h}\")\n",
        "print(f\"  Probability: {prob_h}\")"
      ]
    },
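    {
      "cell_type": "markdown",
      "id": "worked_evaluation",
      "metadata": {},
      "source": [
        "The values printed above can be checked by hand. The following is a sketch, assuming the Gaussian error is half the squared whitened residual and that `logProbability` equals the log normalization constant $\\log k$ minus the error; both assumptions agree with the printed output. For `hybrid_cond_g` with $x_0 = 2$ and $x_1 = 1$:\n",
        "\n",
        "$$\n",
        "E = \\tfrac{1}{2}\\left(\\frac{R x_0 + S x_1 - d}{\\sigma}\\right)^2 = \\tfrac{1}{2}\\left(\\frac{2 \\cdot 2 + 0.5 \\cdot 1 - 1}{0.5}\\right)^2 = \\tfrac{1}{2} \\cdot 7^2 = 24.5\n",
        "$$\n",
        "\n",
        "$$\n",
        "\\log p = \\log k - E = 0.467356 - 24.5 \\approx -24.032644\n",
        "$$\n",
        "\n",
        "For `hybrid_cond_d`, the error is the negative log of the selected table entry: $-\\log 0.2 \\approx 1.609438$. For `hybrid_cond_h` with `D(2)` set to 1 and `X(2)` set to 4.5, mode 1 is selected:\n",
        "\n",
        "$$\n",
        "E = \\tfrac{1}{2}\\left(\\frac{2 \\cdot 4.5 - 10}{0.5}\\right)^2 = 2, \\qquad \\log p = 0.467356 - 2 \\approx -1.532644\n",
        "$$"
      ]
    },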
    {
      "cell_type": "markdown",
      "id": "ac079490",
      "metadata": {},
      "source": [
        "## Restriction (`restrict`)\n",
        "\n",
        "The `restrict` method fixes discrete parent variables to the values given in a `DiscreteValues` assignment, which can simplify the conditional. For example, fixing `D2` in a `HybridGaussianConditional` P(X2 | D2) leaves the single Gaussian component for that mode."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "id": "6f6a5dab",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Restricted HybridConditional (D2=1):p(x2)\n",
            "  R = [ 2 ]\n",
            "  d = [ 10 ]\n",
            "  mean: 1 elements\n",
            "  x2: 5\n",
            "  logNormalizationConstant: 0.467356\n",
            "isotropic dim=1 sigma=0.5\n"
          ]
        }
      ],
      "source": [
        "# Restrict the HybridGaussianConditional P(X2 | D2)\n",
        "assignment = gtsam.DiscreteValues()\n",
        "assignment[D(2)] = 1 # Fix D2 to mode 1\n",
        "\n",
        "restricted_factor = hybrid_cond_h.restrict(assignment)\n",
        "\n",
        "restricted_factor.print(\"\\nRestricted HybridConditional (D2=1):\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "py312",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
